package site.zsxeee.work.spring.oa.controller

import org.apache.shiro.SecurityUtils
import org.apache.shiro.authc.AuthenticationException
import org.apache.shiro.authc.IncorrectCredentialsException
import org.apache.shiro.authc.UnknownAccountException
import org.apache.shiro.authc.UsernamePasswordToken
import org.apache.shiro.crypto.SecureRandomNumberGenerator
import org.apache.shiro.crypto.hash.Sha256Hash
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.stereotype.Controller
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestMethod
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.servlet.ModelAndView
import site.zsxeee.work.spring.oa.database.entity.User
import site.zsxeee.work.spring.oa.database.repository.UserRepository
import site.zsxeee.work.spring.oa.util.AlertInfo
import site.zsxeee.work.spring.oa.util.MVFactory
import site.zsxeee.work.spring.oa.util.Validator
import java.io.ByteArrayOutputStream
import java.nio.charset.Charset
import java.security.SecureRandom
import java.util.*
import javax.imageio.ImageIO
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse

/**
 * Serves the login and registration pages and handles their form submissions.
 *
 * The login form is protected by an image captcha. Its expected answer is kept in the HTTP
 * session under the `validatorCode` attribute (see [validate]).
 */
@Controller
class LoginController {
    @Autowired
    private lateinit var userRepository: UserRepository

    /**
     * Renders the login page with a freshly generated captcha.
     *
     * Bootstrap convenience: when the user table is empty, an `admin` account is created.
     * It gets a random 8-digit password, and the credentials are printed to stdout
     * (see [createDefaultAdminIfNoUsers]).
     */
    @RequestMapping("login", method = [RequestMethod.GET])
    fun loginPage(request: HttpServletRequest) = MVFactory()
            .setTargetViewName("login")
            .setPageTitle("登录")
            .addObject("pageTitle", "登录")
            .make()
            .apply { validate(request, this) }
            .also { createDefaultAdminIfNoUsers() }

    /**
     * Handles a login form submission.
     *
     * @param userName account name entered by the user.
     * @param password plain-text password entered by the user.
     * @param validator the user's answer to the captcha shown on the login page.
     * @param request used for the session-stored captcha answer and the optional `remember` flag.
     * @param response unused; kept so the handler signature stays unchanged.
     * @return a redirect to `/` on success. Otherwise the login view with an error alert
     *         and a new captcha.
     */
    @Suppress("UNUSED_PARAMETER")
    @RequestMapping("login", method = [RequestMethod.POST])
    fun loginRequest(@RequestParam userName: String, @RequestParam password: String, @RequestParam validator:String, request: HttpServletRequest, response: HttpServletResponse): ModelAndView {
        val mv = MVFactory().setPageTitle("登录").setTargetViewName("login")

        // Re-renders the login page with an error alert and a freshly generated captcha.
        fun fail(message: String) = mv
                .setAlert(AlertInfo(AlertInfo.AlertLevel.ERROR, message))
                .make()
                .also { validate(request, it) }

        val session = request.session
        val expectedCode = session.getAttribute("validatorCode")
        // One-time use: once checked, a captcha answer must not be replayable for further
        // attempts. Failure paths below generate a new one via `fail`.
        session.removeAttribute("validatorCode")
        // Locale.ROOT so the upper-casing does not depend on the server's default locale.
        // NOTE(review): assumes Validator.code is upper-case — confirm.
        if (validator.trim().toUpperCase(Locale.ROOT) != expectedCode) {
            return fail("验证码错误！")
        }

        try {
            SecurityUtils.getSubject().login(
                    UsernamePasswordToken(userName, password, request.getParameter("remember") != null)
            )
        } catch (e: UnknownAccountException) {
            return fail("用户名不存在！")
        } catch (e: IncorrectCredentialsException) {
            return fail("用户名或密码错误！")
        } catch (e: AuthenticationException) {
            // Other Shiro authentication failures (locked account, too many attempts,
            // realm errors, ...) previously escaped the handler as an HTTP 500.
            return fail("登录失败！")
        }

        // Let Spring perform the (context-path aware) redirect. Calling sendRedirect and then
        // returning a view made Spring render a page into an already-committed response.
        return ModelAndView("redirect:/")
    }


    /** Renders the registration page. */
    @RequestMapping("register", method = [RequestMethod.GET])
    fun registerPage() = MVFactory()
            .setTargetViewName("register")
            .setPageTitle("注册")
            .make()

    /**
     * Handles a registration form submission.
     *
     * @param userName desired account name (also used as the initial nick name).
     * @param password plain-text password.
     * @param repeatPassword confirmation; must equal [password].
     * @param response unused; kept so the handler signature stays unchanged.
     * @return a redirect to `/login` on success. Otherwise the register view with an error alert.
     */
    @Suppress("UNUSED_PARAMETER")
    @RequestMapping("register", method = [RequestMethod.POST])
    fun registerRequest(@RequestParam userName: String, @RequestParam password: String, @RequestParam repeatPassword:String, response: HttpServletResponse): ModelAndView? {
        val mv = MVFactory().setTargetViewName("register").setPageTitle("注册")

        // Re-renders the register page with an error alert.
        fun fail(message: String) = mv
                .setAlert(AlertInfo(AlertInfo.AlertLevel.ERROR, message))
                .make()

        // @RequestParam only guarantees presence, not content: reject empty submissions.
        if (userName.isBlank() || password.isEmpty()) {
            return fail("用户名和密码不能为空！")
        }
        if (password != repeatPassword) {
            return fail("两次密码不一致！")
        }
        if (userRepository.findByUserName(userName) != null) {
            return fail("用户名已存在！")
        }

        try {
            userRepository.save(newUser(userName, password))
        } catch (e: Exception) {
            // Any persistence failure is reported to the user as a generic error.
            return fail("注册失败！")
        }

        // Same redirect mechanism as loginRequest (context-path aware).
        return ModelAndView("redirect:/login")
    }

    /**
     * Creates the initial `admin` account when the user table is empty. Its generated
     * password is printed to stdout so the operator can sign in for the first time.
     *
     * NOTE(review): the count-then-save is not atomic. Two concurrent first requests could
     * both try to create `admin`. Confirm the user-name column is unique.
     */
    private fun createDefaultAdminIfNoUsers() {
        if (userRepository.count() != 0L) return

        // SecureRandom because this is a credential. Zero-padded so the password is always
        // exactly 8 digits (a bare nextInt could yield something as short as "42").
        val password = "%08d".format(SecureRandom().nextInt(100_000_000))
        userRepository.save(newUser("admin", password))

        println("已自动生成默认用户：")
        println("用户名：admin")
        println("密码：$password")
    }

    /**
     * Builds a [User] with nick name equal to [userName], a fresh random hex salt, and the
     * hex SHA-256 hash (1024 iterations) of [password] with that salt.
     */
    private fun newUser(userName: String, password: String) = User().also { user ->
        user.userName = userName
        user.nickName = userName
        user.salt = SecureRandomNumberGenerator().nextBytes().toHex()
        user.password = Sha256Hash(password, user.salt, 1024).toHex()
    }

    /**
     * Generates a new captcha. Its image is embedded in [mv] as a GIF data URI under
     * `validatorImg`, and its expected answer is stored in the session as `validatorCode`.
     *
     * @throws IllegalStateException if no GIF image writer is available.
     */
    private fun validate(request: HttpServletRequest, mv: ModelAndView){
        val captcha = Validator()
        val gif = ByteArrayOutputStream().also { out ->
            check(ImageIO.write(captcha.codePic, "gif", out)) { "no GIF ImageWriter available" }
        }.toByteArray()
        // Basic (not MIME) encoder: the MIME variant inserts CRLF every 76 characters,
        // which has no place inside a URI.
        mv.addObject("validatorImg", "data:image/gif;base64," + Base64.getEncoder().encodeToString(gif))
        request.session.setAttribute("validatorCode", captcha.code)
    }
}